from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.embeddings import Embeddings
from app.configs.settings import config, ROOT
import torch


def get_embedding_model() -> Embeddings:
    """Load the embedding model selected by the application configuration.

    Reads ``config.embedding.provider`` (case-insensitive) and
    ``config.embedding.name`` and returns a matching LangChain
    ``Embeddings`` implementation.

    Returns:
        An ``OpenAIEmbeddings`` instance for provider ``"openai"``, or a
        ``HuggingFaceBgeEmbeddings`` instance for provider ``"local"``.

    Raises:
        ValueError: If the configured provider is neither ``"openai"``
            nor ``"local"``.
    """
    provider = config.embedding.provider.lower()
    model_name = config.embedding.name

    if provider == "openai":
        print("Loading OpenAI embedding model...")
        return OpenAIEmbeddings(model=model_name)

    if provider == "local":
        print(f"Loading local embedding model: {model_name}")

        # Device preference: CUDA first, then MPS (Apple Silicon), else CPU.
        device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        print(f"Using device: {device}")

        # HuggingFaceBgeEmbeddings is tailored to BGE-family models;
        # normalized embeddings are the recommended setting for BGE.
        return HuggingFaceBgeEmbeddings(
            model_name=model_name,
            model_kwargs={"device": device},
            encode_kwargs={"normalize_embeddings": True},
        )

    raise ValueError(f"Unsupported embedding provider: {provider}")


# Module-level embedding instance: resolved once at import time via
# get_embedding_model() and shared by get_vector_store(). NOTE(review):
# for a "local" provider this loads model weights as an import side effect.
embeddings = get_embedding_model()


def get_vector_store() -> Chroma:
    """Build the Chroma vector store configured for this application.

    Uses the collection name and persistence directory from
    ``config.vector_store`` (the directory is resolved relative to the
    project ROOT) and the module-level ``embeddings`` instance as the
    embedding function.

    Returns:
        A ``Chroma`` vector store bound to the configured collection.
    """
    persist_dir = str(ROOT / config.vector_store.persist_directory)
    print(f"Initializing vector store at: {persist_dir}")

    return Chroma(
        collection_name=config.vector_store.collection_name,
        embedding_function=embeddings,
        persist_directory=persist_dir,
    )